{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了进一步增强主题关键词提取的效果，可以采用多头注意力机制，并结合不同层次的注意力权重。通过这种方法，我们可以从不同角度和层次提取关键词，使得最终的关键词更加全面和准确。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
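    "缺点：分词有问题——`TfidfVectorizer` 默认按空白和标点切词，不适用于未分词的中文；TF-IDF 必须与 BERT 使用同一套分词，否则在 `tfidf_weight` 中查不到 BERT 的 token，权重恒为 0。"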
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "import torch\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/py311/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.\n",
      "A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained BERT model and its tokenizer.\n",
    "# output_attentions=True is requested because extract_attention_keywords reads output.attentions.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "model = AutoModel.from_pretrained(\"bert-base-chinese\", output_attentions=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data preparation: example documents (one short Chinese sentence per entry).\n",
    "documents = [\n",
    "    \"机器学习是一个非常有趣的领域。\",\n",
    "    \"人工智能是未来的发展方向。\",\n",
    "    \"今天天气很好，适合出去散步。\",\n",
    "    \"我喜欢在周末去爬山。\",\n",
    "    \"深度学习模型需要大量的数据。\",\n",
    "    \"自然语言处理是人工智能的重要组成部分。\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute TF-IDF weights.\n",
    "# TfidfVectorizer's default token_pattern only splits on whitespace/punctuation, so an\n",
    "# unsegmented Chinese sentence would become one giant \"word\" that can never equal a\n",
    "# BERT token. Reusing the BERT tokenizer keeps the TF-IDF vocabulary aligned with the\n",
    "# tokens that are later looked up in the per-document weight dict.\n",
    "vectorizer = TfidfVectorizer(\n",
    "    tokenizer=tokenizer.tokenize,\n",
    "    token_pattern=None,  # unused once a custom tokenizer is given; None silences the warning\n",
    "    lowercase=False,  # leave casing to the BERT tokenizer so feature names match its tokens\n",
    ")\n",
    "tfidf_matrix = vectorizer.fit_transform(documents)\n",
    "feature_names = vectorizer.get_feature_names_out()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_attention_keywords(text, tfidf_weight, top_k=3):\n",
    "    \"\"\"Rank the tokens of ``text`` by TF-IDF weight times [CLS] attention.\n",
    "\n",
    "    Uses the module-level ``tokenizer`` and ``model``; the model must have been\n",
    "    loaded with ``output_attentions=True``.\n",
    "\n",
    "    Args:\n",
    "        text: Document to extract keywords from.\n",
    "        tfidf_weight: Mapping from token string to its TF-IDF weight for this\n",
    "            document. Tokens missing from the mapping count as 0.\n",
    "        top_k: Maximum number of keywords to return.\n",
    "\n",
    "    Returns:\n",
    "        Up to ``top_k`` distinct token strings, highest combined score first.\n",
    "    \"\"\"\n",
    "    special_tokens = {\"[PAD]\", \"[UNK]\", \"[CLS]\", \"[SEP]\"}\n",
    "\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True, max_length=128)\n",
    "    # Inference only: no need to track gradients.\n",
    "    with torch.no_grad():\n",
    "        output = model(**inputs)\n",
    "\n",
    "    # One attention tensor per layer, indexed below as [batch, head, query, key].\n",
    "    attentions = output.attentions\n",
    "\n",
    "    # Attention paid by [CLS] (query position 0) to every token: averaged over the\n",
    "    # heads of a layer, then summed across all layers.\n",
    "    combined_attention_scores = np.zeros(inputs[\"input_ids\"].shape[1])\n",
    "    for layer_attention in attentions:\n",
    "        cls_attention_heads = layer_attention[0, :, 0, :]\n",
    "        combined_attention_scores += cls_attention_heads.mean(dim=0).detach().numpy()\n",
    "\n",
    "    tokens = tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0])\n",
    "\n",
    "    # Combine the accumulated attention with the TF-IDF weight of each token.\n",
    "    combined_scores = {}\n",
    "    for idx, token in enumerate(tokens):\n",
    "        if token in special_tokens:\n",
    "            # Special tokens are never keywords.\n",
    "            continue\n",
    "        tfidf_score = tfidf_weight.get(token, 0)\n",
    "        score = tfidf_score * combined_attention_scores[idx]\n",
    "        # A token may occur several times; keep its best occurrence instead of\n",
    "        # letting whichever comes last overwrite the others.\n",
    "        combined_scores[token] = max(score, combined_scores.get(token, score))\n",
    "\n",
    "    # Highest combined score first.\n",
    "    sorted_tokens = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)\n",
    "\n",
    "    keywords = [token for token, _ in sorted_tokens[:top_k]]\n",
    "\n",
    "    return keywords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "文档: 机器学习是一个非常有趣的领域。\n",
      "提取的主题关键词: ['[CLS]', '机', '器']\n",
      "\n",
      "文档: 人工智能是未来的发展方向。\n",
      "提取的主题关键词: ['[CLS]', '人', '工']\n",
      "\n",
      "文档: 今天天气很好，适合出去散步。\n",
      "提取的主题关键词: ['[CLS]', '今', '天']\n",
      "\n",
      "文档: 我喜欢在周末去爬山。\n",
      "提取的主题关键词: ['[CLS]', '我', '喜']\n",
      "\n",
      "文档: 深度学习模型需要大量的数据。\n",
      "提取的主题关键词: ['[CLS]', '深', '度']\n",
      "\n",
      "文档: 自然语言处理是人工智能的重要组成部分。\n",
      "提取的主题关键词: ['[CLS]', '自', '然']\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run keyword extraction over every document and print the result.\n",
    "tfidf_dense = tfidf_matrix.toarray()  # densify once instead of once per row\n",
    "for i, doc in enumerate(documents):\n",
    "    # Map each TF-IDF vocabulary entry to its weight in this document.\n",
    "    tfidf_scores = {name: weight for name, weight in zip(feature_names, tfidf_dense[i])}\n",
    "    keywords = extract_attention_keywords(doc, tfidf_scores)\n",
    "    print(f\"文档: {doc}\\n提取的主题关键词: {keywords}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
